Source code for hysop.symbolic.spectral

# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import sympy as sm
import numpy as np

from hysop.constants import BoundaryCondition, BoundaryExtension, TransformType
from hysop.tools.htypes import check_instance, to_tuple, first_not_None
from hysop.tools.sympy_utils import Expr, Symbol, Dummy, subscript
from hysop.tools.spectral_utils import SpectralTransformUtils as STU
from hysop.symbolic import SpaceSymbol
from hysop.symbolic.array import SymbolicBuffer
from hysop.symbolic.field import (
    FieldExpressionBuilder,
    FieldExpressionI,
    TensorBase,
    SymbolicField,
    AppliedSymbolicField,
)
from hysop.symbolic.frame import SymbolicFrame
from hysop.fields.continuous_field import Field, ScalarField, TensorField
from hysop.tools.spectral_utils import SpectralTransformUtils


class WaveNumberIndex(sm.Symbol):
    """Symbolic index named ``i<axis>`` attached to one wave number axis.

    The symbol only becomes usable as a real index once :meth:`bind_axes`
    has been called; :attr:`real_index` then returns the matching entry of
    ``hysop.symbolic.local_indices_symbols``.
    """

    def __new__(cls, axis):
        instance = super().__new__(cls, f"i{axis}")
        instance.axis = axis
        # NOTE(review): sympy may cache Symbol instances by (cls, name); if so,
        # calling WaveNumberIndex(axis) again resets the binding below on the
        # shared object — confirm this is intended.
        instance._axes, instance._real_index = None, None
        return instance

    def __init__(self, axis):
        super().__init__()

    def bind_axes(self, axes):
        """Bind this index to the sequence ``axes`` and resolve its real index.

        Binding twice is only allowed with equal ``axes`` (checked by assert).
        The real index is looked up at position ``len(axes) - 1 - axes.index(axis)``
        in ``local_indices_symbols``, i.e. in reversed axis order.
        """
        assert (self._axes is None) or (axes == self._axes)
        ndim = len(axes)
        # Imported lazily (kept as in the original module layout).
        from hysop.symbolic import local_indices_symbols

        self._axes = axes
        reversed_position = ndim - 1 - axes.index(self.axis)
        self._real_index = local_indices_symbols[reversed_position]

    @property
    def real_index(self):
        """Real index symbol chosen by :meth:`bind_axes`.

        Raises:
            RuntimeError: if :meth:`bind_axes` has not been called yet.
        """
        bound = self._real_index
        if bound is None:
            raise RuntimeError("No axes bound yet !")
        return bound
class WaveNumber(Dummy):
    """Wave number symbol for SpectralTransform derivatives (and integrals).

    Instances are memoized per ``(transform, axis, exponent)``: building the
    same wave number twice returns the same object. The symbol name encodes
    the axis, a short transform tag and the exponent, e.g. ``k0_c2c`` or
    ``k1_c2__2`` (pretty form uses subscripts).
    """

    # Short transform tags used in symbol names. Note that the inverse
    # DCT/DST of types II and III swap tags (IDCT_II -> "c3", IDCT_III -> "c2").
    __transform2str = {
        TransformType.FFT: "c2c",
        TransformType.RFFT: "r2c",
        TransformType.DCT_I: "c1",
        TransformType.DCT_II: "c2",
        TransformType.DCT_III: "c3",
        TransformType.DCT_IV: "c4",
        TransformType.DST_I: "s1",
        TransformType.DST_II: "s2",
        TransformType.DST_III: "s3",
        TransformType.DST_IV: "s4",
        TransformType.IFFT: "c2c",
        TransformType.IRFFT: "r2c",
        TransformType.IDCT_I: "c1",
        TransformType.IDCT_II: "c3",
        TransformType.IDCT_III: "c2",
        TransformType.IDCT_IV: "c4",
        TransformType.IDST_I: "s1",
        TransformType.IDST_II: "s3",
        TransformType.IDST_III: "s2",
        TransformType.IDST_IV: "s4",
    }

    # Memoization table: (transform, axis, exponent) -> WaveNumber instance.
    __wave_numbers = {}

    def __new__(cls, axis, transform, exponent, **kwds):
        """Return the (memoized) wave number ``k_axis ** exponent`` for ``transform``.

        Returns ``None`` when ``transform`` is ``TransformType.NONE``.
        Raises whatever ``check_instance`` raises on invalid arguments.
        """
        check_instance(transform, TransformType)
        check_instance(axis, int, minval=0)
        check_instance(exponent, int, minval=1)
        if transform is TransformType.NONE:
            return None
        # NOTE(review): check_instance(..., minval=1) above appears to reject
        # exponent <= 0, which makes the `exponent == 0` and `exponent < 0`
        # paths below currently unreachable. They are kept (with sign-correct
        # storage) for integral wave numbers — confirm whether integrals should
        # be enabled by relaxing the check.
        if exponent == 0:
            return 1

        key = (transform, axis, exponent)
        cached = cls.__wave_numbers.get(key)
        if cached is not None:
            return cached

        tr_str = cls.__transform2str[transform]
        if len(tr_str) == 2:
            # e.g. "c2" -> "c" + subscript 2 for the pretty name
            tr_pstr = tr_str[0] + subscript(int(tr_str[1]))
        else:
            tr_pstr = tr_str

        magnitude = abs(exponent)
        prefix = "i" if exponent < 0 else ""
        suffix = f"__{magnitude}" if magnitude > 1 else ""
        name = f"{prefix}k{axis}_{tr_str}{suffix}"
        pretty_name = prefix + "k" + subscript(axis) + "_" + tr_pstr + suffix

        obj = super().__new__(cls, name=name, pretty_name=pretty_name, **kwds)
        obj._axis = int(axis)
        obj._transform = transform
        # Keep the sign: it is what distinguishes k**n from k**-n in
        # __eq__/__hash__ and lets pow() compound exponents correctly.
        obj._exponent = int(exponent)
        cls.__wave_numbers[key] = obj
        return obj

    def __init__(self, axis, transform, exponent, **kwds):
        super().__init__(name=None, pretty_name=None, **kwds)

    @property
    def axis(self):
        return self._axis

    @property
    def transform(self):
        return self._transform

    @property
    def exponent(self):
        return self._exponent

    @property
    def is_real(self):
        """True for real-to-real transforms, or any transform with an even exponent."""
        return bool(STU.is_R2R(self._transform)) or (self._exponent % 2 == 0)

    @property
    def is_complex(self):
        """True for non real-to-real transforms raised to an odd exponent."""
        return (not STU.is_R2R(self._transform)) and (self._exponent % 2 != 0)

    def pow(self, exponent):
        """Return this wave number raised to ``exponent`` (exponents multiply)."""
        return WaveNumber(
            axis=self.axis,
            transform=self.transform,
            exponent=exponent * self.exponent,
        )

    def indexed_buffer(self, name=None):
        """Return ``buf[i<axis>]`` on a SymbolicBuffer named ``name`` (default: this symbol's name).

        The returned indexed object carries this wave number as its ``Wn`` attribute.
        """
        name = first_not_None(name, self.name)
        buf = SymbolicBuffer(name=name, memory_object=None)
        obj = buf[WaveNumberIndex(self.axis)]
        obj.Wn = self
        return obj

    def __eq__(self, other):
        if not isinstance(other, WaveNumber):
            return NotImplemented
        return (self.axis, self.transform, self.exponent) == (
            other.axis,
            other.transform,
            other.exponent,
        )

    def __hash__(self):
        return hash((self.axis, self.transform, self.exponent))
class AppliedSpectralTransform(AppliedSymbolicField):
    """
    An applied spectral transform.

    Instances are returned by ``SpectralTransform(...)`` (which ends with
    ``obj(*all_vars)``). The private attributes read by the properties below
    (``_field``, ``_transforms``, ``_wave_numbers``, ...) are set on that
    ``obj`` — presumably resolved here through the applied function's class;
    confirm against SymbolicField.
    """

    def short_description(self):
        """Return a one-line summary: field, transformed axes, direction, transforms."""
        ss = "{}(field={}, axes={}, is_forward={}, transforms=[{}])"
        return ss.format(
            self.__class__.__name__,
            self.field.pretty_name,
            self.transformed_axes,
            "1" if self.is_forward else "0",
            self.format_transforms(),
        )

    def long_description(self):
        """Return a multi-line summary of all transform metadata."""
        ss = """
== {} ==
*field: {}
*transformed_axes: {}
*spatial_axes: {}
*is_forward: {}
*transforms: {}
*freq_vars: {}
*space_vars: {}
*all_vars: {}
*wave_numbers: {}
"""
        # Arguments must follow the template order (freq_vars before space_vars).
        return ss.format(
            self.__class__.__name__,
            self.field.short_description(),
            self.transformed_axes,
            self.spatial_axes,
            self.is_forward,
            self.transforms,
            self.freq_vars,
            self.space_vars,
            self.all_vars,
            self.wave_numbers,
        )

    def format_transforms(self):
        """Return the per-axis transforms joined with ``" x "``."""
        return " x ".join(str(tr) for tr in self.transforms)

    @property
    def field(self):
        return self._field

    @property
    def transformed_axes(self):
        return self._transformed_axes

    @property
    def spatial_axes(self):
        return self._spatial_axes

    @property
    def freq_vars(self):
        return self._freq_vars

    @property
    def space_vars(self):
        return self._space_vars

    @property
    def all_vars(self):
        return self._all_vars

    @property
    def frame(self):
        return self._frame

    @property
    def lboundaries(self):
        return self._field.lboundaries

    @property
    def rboundaries(self):
        return self._field.rboundaries

    @property
    def domain(self):
        return self._field.domain

    @property
    def dtype(self):
        return self._field.dtype

    @property
    def transforms(self):
        return self._transforms

    @property
    def wave_numbers(self):
        return self._wave_numbers

    @property
    def is_forward(self):
        return self._is_forward

    # SYMPY INTERNALS ################
    @property
    def is_number(self):
        return False

    @property
    def free_symbols(self):
        return set(self._all_vars)

    def _eval_derivative(self, v):
        # d/dv on a frequency variable becomes a multiplication by the wave
        # number of the same position in all_vars; other variables stay as an
        # unevaluated Derivative.
        if v in self._freq_vars:
            i = self._all_vars.index(v)
            return self._wave_numbers[i] * self
        return sm.Derivative(self, v)

    def _hashable_content(self):
        """See sympy.core.basic.Basic._hashable_content()"""
        hc = super()._hashable_content()
        hc += (self.__class__,)
        return hc

    def __hash__(self):
        "Fix sympy v1.2 hashes"
        h = super().__hash__()
        for hc in (self.__class__,):
            # Mix in the extra hashable content (was `hash(h)`, which cancelled
            # h out and made every instance hash to 0).
            h ^= hash(hc)
        return h

    def __eq__(self, other):
        "Fix sympy v1.2 eq"
        eq = super().__eq__(other)
        if eq is not True:
            return eq
        eq &= self.__class__ is other.__class__
        return eq

    def __ne__(self, other):
        "Fix sympy v1.2 neq"
        return not (self == other)

    ###################################
class SpectralTransform(SymbolicField):
    """
    A single spectral transform that may be applied.
    This object can also be used as a sympy expression (and a FieldExpression).
    This expression carries datatype and boundary conditions.
    """

    def __new__(cls, field, axes=None, forward=True):
        """Build the applied spectral transform of ``field`` along ``axes``.

        Parameters:
            field: a ScalarField, or a TensorField (transformed component-wise).
            axes: axes to transform; defaults to all axes of the field.
            forward: build the forward transform if True, else the inverse.

        Returns:
            For a TensorField, a TensorBase array of applied transforms;
            otherwise ``obj(*all_vars)``, an applied transform expression.
        """
        if isinstance(field, TensorField):
            T = field.new_empty_array()
            for idx, f in field.nd_iter():
                T[idx] = cls(field=f, axes=axes, forward=forward)
            T = T.view(TensorBase)
            T.frame = T[0].frame
            return T

        # Validate before touching any field attribute.
        check_instance(field, ScalarField)
        dim = field.dim
        axes = to_tuple(first_not_None(axes, range(dim)))
        check_instance(axes, tuple, values=int, minval=0, maxval=dim - 1, minsize=1)

        transformed_axes = tuple(sorted(set(axes)))
        spatial_axes = tuple(sorted(set(range(dim)) - set(axes)))

        # Frame symbols are addressed as [dim - 1 - axis] (reversed axis order).
        frame = field.domain.frame
        freq_vars = tuple(frame.freqs[dim - 1 - i] for i in transformed_axes[::-1])
        space_vars = tuple(frame.coords[dim - 1 - i] for i in spatial_axes[::-1])
        all_vars = tuple(
            frame.freqs[dim - 1 - i] if i in transformed_axes else frame.coords[dim - 1 - i]
            for i in range(dim)
        )[::-1]

        transforms = SpectralTransformUtils.transforms_from_field(
            field, transformed_axes=transformed_axes
        )
        for i in range(frame.dim):
            assert (transforms[i] is TransformType.NONE) ^ (i in transformed_axes)

        # Wave numbers are generated from the forward transforms, before the
        # optional inversion below.
        wave_numbers = cls.generate_wave_numbers(transforms)[::-1]
        if not forward:
            transforms = SpectralTransformUtils.get_inverse_transforms(*transforms)

        frame = SymbolicFrame(dim=dim, freq_axes=transformed_axes)
        assert frame.coords == all_vars

        obj = super().__new__(cls, field=field, bases=(AppliedSpectralTransform,))
        obj._field = field
        obj._transformed_axes = transformed_axes
        obj._spatial_axes = spatial_axes
        obj._freq_vars = freq_vars
        obj._space_vars = space_vars
        obj._is_forward = forward
        obj._all_vars = all_vars
        obj._transforms = transforms
        obj._wave_numbers = wave_numbers
        obj._frame = frame
        return obj(*all_vars)

    @classmethod
    def generate_wave_numbers(cls, transforms):
        """Delegate to SpectralTransformUtils.generate_wave_numbers(*transforms)."""
        return SpectralTransformUtils.generate_wave_numbers(*transforms)

    def _hashable_content(self):
        """See sympy.core.basic.Basic._hashable_content()"""
        hc = super()._hashable_content()
        hc += (self._transformed_axes, self._is_forward)
        return hc

    def __hash__(self):
        "Fix sympy v1.2 hashes"
        h = super().__hash__()
        for hc in (self._transformed_axes, self._is_forward):
            h ^= hash(hc)
        return h

    def __eq__(self, other):
        "Fix sympy v1.2 eq"
        eq = super().__eq__(other)
        if eq is not True:
            return eq
        return (self._transformed_axes, self._is_forward) == (
            other._transformed_axes,
            other._is_forward,
        )

    def __ne__(self, other):
        "Fix sympy v1.2 neq"
        return not (self == other)
if __name__ == "__main__":
    # Demo: solve laplacian(psi_hat) - W_hat == 0 for psi_hat, take its curl,
    # and print the parsed spectral assignments for U_hat.
    from hysop.tools.sympy_utils import sstr
    from hysop import Box
    from hysop.constants import BoxBoundaryCondition
    from hysop.defaults import VelocityField, VorticityField
    from hysop.symbolic.field import laplacian, curl
    from hysop.symbolic.relational import Assignment
    from hysop.tools.sympy_utils import Greak

    BC = BoxBoundaryCondition
    box = Box(
        dim=3,
        lboundaries=(BC.SYMMETRIC, BC.OUTFLOW, BC.SYMMETRIC),
        rboundaries=(BC.SYMMETRIC, BC.OUTFLOW, BC.OUTFLOW),
    )
    velocity = VelocityField(domain=box)
    vorticity = VorticityField(velocity=velocity)
    psi = vorticity.field_like(name="psi", pretty_name=Greak[23])

    W_hat = SpectralTransform(vorticity, forward=True)
    U_hat = SpectralTransform(velocity, forward=False)
    psi_hat = SpectralTransform(psi)

    poisson = laplacian(psi_hat, psi_hat.frame) - W_hat
    psi_solution = sm.solve(poisson, psi_hat.tolist())
    curl_psi = curl(psi_hat, psi_hat.frame).xreplace(psi_solution)

    def show(title, value):
        print(title)
        print(value)
        print()

    show("VELOCITY", velocity.short_description())
    show("VORTICITY", vorticity.short_description())
    show("W_hat", W_hat)
    show("U_hat", U_hat)
    show("Psi_hat", psi_hat)

    for assignment in Assignment.assign(U_hat, curl_psi):
        parsed, spectral_transforms, wave_numbers = (
            SpectralTransformUtils.parse_expression(assignment)
        )
        print()
        print(parsed)
        for transform in spectral_transforms:
            print(transform.short_description())
        print(wave_numbers)